{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "xla_flags = os.environ.get(\"XLA_FLAGS\", \"\")\n",
    "xla_flags += \" --xla_gpu_triton_gemm_any=True\"\n",
    "os.environ[\"XLA_FLAGS\"] = xla_flags\n",
    "os.environ[\"MUJOCO_GL\"] = \"egl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import re\n",
    "from datetime import datetime\n",
    "from typing import Tuple\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jp\n",
    "import matplotlib.pyplot as plt\n",
    "import mediapy as media\n",
    "import numpy as np\n",
    "from brax.training.agents.sac import train as sac\n",
    "from IPython.display import clear_output, display\n",
    "\n",
    "from mujoco_playground import dm_control_suite as suite\n",
    "from mujoco_playground import wrapper\n",
    "\n",
    "# Enable persistent compilation cache.\n",
    "jax.config.update(\"jax_compilation_cache_dir\", \"/tmp/jax_cache\")\n",
    "jax.config.update(\"jax_persistent_cache_min_entry_size_bytes\", -1)\n",
    "jax.config.update(\"jax_persistent_cache_min_compile_time_secs\", 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "CAMERAS = {\n",
    "    \"AcrobotSwingup\": \"fixed\",\n",
    "    \"BallInCup\": \"cam0\",\n",
    "    \"CartpoleSwingup\": \"fixed\",\n",
    "    \"CartpoleBalance\": \"fixed\",\n",
    "    \"CheetahRun\": \"side\",\n",
    "    \"HumanoidRun\": \"side\",\n",
    "    \"HumanoidStand\": \"side\",\n",
    "    \"HumanoidWalk\": \"side\",\n",
    "    \"PointMass\": \"cam0\",\n",
    "    \"WalkerRun\": \"side\",\n",
    "    \"WalkerWalk\": \"side\",\n",
    "    \"WalkerStand\": \"side\",\n",
    "    \"HopperHop\": \"cam0\",\n",
    "    \"HopperStand\": \"cam0\",\n",
    "    \"FishSwim\": \"fixed_top\",\n",
    "    \"ReacherEasy\": \"fixed\",\n",
    "    \"ReacherHard\": \"fixed\",\n",
    "    \"Swimmer6\": \"tracking1\",\n",
    "    \"BarkourJoystick\": \"track\",\n",
    "    \"PendulumSwingup\": \"fixed\",\n",
    "    \"FingerSpin\": \"cam0\",\n",
    "    \"FingerTurnEasy\": \"cam0\",\n",
    "    \"FingerTurnHard\": \"cam0\",\n",
    "    \"DogStand\": \"y-axis\",\n",
    "}\n",
    "\n",
    "DISCOUNTS = {\n",
    "    \"HopperHop\": 0.99,\n",
    "    \"CartpoleBalance\": 0.99,\n",
    "    \"CartpoleSwingup\": 0.99,\n",
    "    \"HumanoidRun\": 0.95,\n",
    "    \"AcrobotSwingup\": 0.99,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"WalkerWalk\"\n",
    "env_cfg = suite.get_default_config(env_name)\n",
    "env = suite.load(env_name, config=env_cfg)\n",
    "print(env_cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unit Test Env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jit_reset = jax.jit(env.reset)\n",
    "jit_step = jax.jit(env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(0)\n",
    "state = jit_reset(key)\n",
    "\n",
    "action = jp.zeros((env.action_size,))\n",
    "state = jit_step(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(12345)\n",
    "key, key_reset = jax.random.split(key)\n",
    "state = jit_reset(key_reset)\n",
    "states = [state]\n",
    "actions = []\n",
    "f = 0.5\n",
    "for i in range(env_cfg.episode_length):\n",
    "  key, key_action = jax.random.split(key)\n",
    "  action = []\n",
    "  for j in range(env.action_size):\n",
    "    action.append(\n",
    "        jp.sin(\n",
    "            state.data.time * 2 * jp.pi * f + j * 2 * jp.pi / env.action_size\n",
    "        )\n",
    "    )\n",
    "  action = jp.array(action)\n",
    "  actions.append(action)\n",
    "  state = jit_step(state, action)\n",
    "  states.append(state)\n",
    "frames = env.render(states, camera=CAMERAS[env_name])\n",
    "media.show_video(frames, fps=1.0 / env.dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dm_control import suite as dmc_suite\n",
    "\n",
    "\n",
    "def env_name_to_domain_and_task_name(env_name: str) -> Tuple[str, str]:\n",
    "  domain_name, *task_name = re.split(\"(?<=.)(?=[A-Z])\", env_name)\n",
    "  if len(task_name) == 1:\n",
    "    task_name = task_name[0]\n",
    "  else:\n",
    "    task_name = \"_\".join(task_name)\n",
    "  return domain_name.lower(), task_name.lower()\n",
    "\n",
    "\n",
    "domain_name, task_name = env_name_to_domain_and_task_name(env_name)\n",
    "dmc_env = dmc_suite.load(domain_name, task_name, task_kwargs={\"random\": 0})\n",
    "\n",
    "action_spec = dmc_env.action_spec()\n",
    "\n",
    "frames = []\n",
    "timestep = dmc_env.reset()\n",
    "f = 0.5\n",
    "while not timestep.last():\n",
    "  action = []\n",
    "  for i in range(action_spec.shape[0]):\n",
    "    action.append(\n",
    "        np.sin(\n",
    "            dmc_env.physics.time() * 2 * np.pi * f\n",
    "            + i * 2 * np.pi / action_spec.shape[0]\n",
    "        )\n",
    "    )\n",
    "  action = np.array(action)\n",
    "  timestep = dmc_env.step(action)\n",
    "  frames.append(dmc_env.physics.render(camera_id=CAMERAS[env_name]))\n",
    "media.show_video(frames, fps=1.0 / dmc_env.control_timestep())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(0)\n",
    "frames = []\n",
    "for _ in range(5):\n",
    "  key, key_reset = jax.random.split(key)\n",
    "  state = jit_reset(key_reset)\n",
    "  frames.append(env.render(state, camera=CAMERAS[env_name]))\n",
    "media.show_image(np.hstack(frames))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = []\n",
    "for _ in range(5):\n",
    "  frames.append(dmc_env.physics.render(camera_id=CAMERAS[env_name]))\n",
    "media.show_image(np.hstack(frames))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs = {\n",
    "    \"num_timesteps\": 4_000_000,\n",
    "    \"num_evals\": 10,\n",
    "    \"reward_scaling\": 1.0,\n",
    "    \"episode_length\": env_cfg.episode_length,\n",
    "    \"normalize_observations\": True,\n",
    "    \"action_repeat\": 1,\n",
    "    \"discounting\": DISCOUNTS.get(env_name, 0.95),\n",
    "    \"learning_rate\": 1e-3,\n",
    "    \"num_envs\": 128,\n",
    "    \"batch_size\": 512,\n",
    "    \"grad_updates_per_step\": 8,\n",
    "    \"max_replay_size\": 1048576 * 4,\n",
    "    \"min_replay_size\": 8192,\n",
    "    \"seed\": 0,\n",
    "    \"max_devices_per_host\": 1,\n",
    "}\n",
    "\n",
    "x_data, y_data, y_dataerr = [], [], []\n",
    "times = [datetime.now()]\n",
    "\n",
    "\n",
    "def progress(num_steps, metrics):\n",
    "  clear_output(wait=True)\n",
    "\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics[\"eval/episode_reward\"])\n",
    "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
    "\n",
    "  plt.xlim([0, kwargs[\"num_timesteps\"] * 1.25])\n",
    "  plt.ylim([0, 1100])\n",
    "  plt.xlabel(\"# environment steps\")\n",
    "  plt.ylabel(\"reward per episode\")\n",
    "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
    "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
    "\n",
    "  display(plt.gcf())\n",
    "\n",
    "\n",
    "train_fn = functools.partial(sac.train, **kwargs, progress_fn=progress, wrap_env_fn=wrapper.wrap_for_brax_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "make_inference_fn, params, metrics = train_fn(environment=env)\n",
    "print(f\"time to jit: {times[1] - times[0]}\")\n",
    "print(f\"time to train: {times[-1] - times[1]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k, v in metrics.items():\n",
    "  print(f\"{k}: {v}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "jit_reset = jax.jit(env.reset)\n",
    "jit_step = jax.jit(env.step)\n",
    "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(2)\n",
    "rollout = []\n",
    "n_episodes = 1\n",
    "\n",
    "for _ in range(n_episodes):\n",
    "  state = jit_reset(rng)\n",
    "  rollout.append(state)\n",
    "  for i in range(env_cfg.episode_length):\n",
    "    act_rng, rng = jax.random.split(rng)\n",
    "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "    state = jit_step(state, ctrl)\n",
    "    rollout.append(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "render_every = 2\n",
    "frames = env.render(rollout[::render_every], camera=CAMERAS[env_name])\n",
    "rewards = [s.reward for s in rollout]\n",
    "media.show_video(frames, fps=1.0 / env.dt / render_every)\n",
    "media.write_video(f\"{env_name}.mp4\", frames, fps=1.0 / env.dt / render_every)\n",
    "\n",
    "plt.plot(np.convolve(rewards, np.ones(100) / 100, mode=\"valid\"))\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"reward\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
